import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch.nn.functional as F
import torchvision.models as models

class MultiHeadSelfAttention2d(nn.Module):
    """Multi-head self-attention over the spatial positions of a 2-D feature map.

    Queries and keys are 1x1-projected to ``in_channels // 8`` channels and
    split evenly across ``num_heads`` (``d_k`` channels per head). Values keep
    all ``in_channels`` channels and are split across heads the same way. The
    attended output is blended into the input through a learnable scalar
    ``gamma`` initialized to zero, so at initialization the module returns its
    input unchanged.

    Args:
        in_channels: Channel count C of the input feature map.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``in_channels // 8`` or ``in_channels`` is not divisible
            by ``num_heads``.
    """

    def __init__(self, in_channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads

        d_k_total = in_channels // 8
        if d_k_total % num_heads != 0:
            raise ValueError("in_channels//8 must be divisible by num_heads")
        # forward() splits the value channels per head with view(); without this
        # check an incompatible in_channels only fails later with an opaque error.
        if in_channels % num_heads != 0:
            raise ValueError("in_channels must be divisible by num_heads")
        self.d_k = d_k_total // num_heads

        # Creation order (query, key, value) is kept so initialization is unchanged.
        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)

        # Residual scale; starts at 0 so the block begins as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, C, H, W); returns the same shape."""
        B, C, H, W = x.size()
        N = H * W

        # (B, heads, N, d_k): one query vector per spatial position and head.
        query = self.query_conv(x)
        query = query.view(B, self.num_heads, self.d_k, N).permute(0, 1, 3, 2)

        # (B, heads, d_k, N)
        key = self.key_conv(x)
        key = key.view(B, self.num_heads, self.d_k, N)

        # (B, heads, N, N); row i holds the weights position i puts on every position.
        # No 1/sqrt(d_k) scaling is applied.
        energy = torch.matmul(query, key)
        attention = self.softmax(energy)

        # (B, heads, C // heads, N)
        value = self.value_conv(x)
        head_dim = C // self.num_heads
        value = value.view(B, self.num_heads, head_dim, N)

        # out[..., d, i] = sum_j value[..., d, j] * attention[..., i, j]
        out = torch.matmul(value, attention.transpose(-2, -1))

        # Merge heads back into the channel axis.
        out = out.view(B, C, H, W)

        out = self.gamma * out + x
        return out

class ResNet18SelfAttentionPredictor(nn.Module):
    """Dense regressor: ResNet-18 trunk, 2-D self-attention and a 1x1 output head.

    A randomly initialized torchvision ResNet-18 (``weights=None``) is modified
    to preserve spatial resolution:
    - the stem max-pool is replaced by an identity;
    - in ``layer2``-``layer4``, every ``conv1`` with stride (2, 2) and every
      downsample-shortcut conv with stride (2, 2) is switched to stride (1, 1).

    The 512-channel trunk output goes through ``MultiHeadSelfAttention2d``. A
    1x1 convolution then maps it to ``out_channels``, and the result is
    bilinearly resized back to the input's height and width.

    Args:
        in_channels: Input channel count; the stem conv is replaced when not 3.
        out_channels: Number of output maps.
        num_heads: Number of attention heads in the attention block.
    """

    def __init__(self, in_channels=3, out_channels=1, num_heads=4):
        super().__init__()

        trunk = models.resnet18(weights=None)

        if in_channels != 3:
            # New stem with the same kernel/stride/padding as ResNet's default conv1.
            trunk.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2,
                                    padding=3, bias=False)

        trunk.maxpool = nn.Identity()

        halved = (2, 2)
        for stage in (trunk.layer2, trunk.layer3, trunk.layer4):
            for block in stage:
                if block.conv1.stride == halved:
                    block.conv1.stride = (1, 1)
                shortcut = block.downsample
                if shortcut is None:
                    continue
                shortcut_conv = shortcut[0]
                if getattr(shortcut_conv, "stride", None) == halved:
                    shortcut_conv.stride = (1, 1)

        # Positional Sequential -> submodule keys "0".."7", in this order.
        trunk_parts = ("conv1", "bn1", "relu", "maxpool",
                       "layer1", "layer2", "layer3", "layer4")
        self.backbone = nn.Sequential(*(getattr(trunk, part) for part in trunk_parts))

        self.attention = MultiHeadSelfAttention2d(in_channels=512, num_heads=num_heads)

        self.conv_out = nn.Conv2d(512, out_channels, kernel_size=1)

    def forward(self, x):
        """Map a (B, in_channels, H, W) batch to (B, out_channels, H, W)."""
        _, _, height, width = x.shape
        attended = self.attention(self.backbone(x))
        coarse = self.conv_out(attended)
        return F.interpolate(coarse, size=(height, width), mode='bilinear',
                             align_corners=False)


class CNNIntensityDataset(Dataset):
    """Freeze-intensity maps for a single circuit design (qubits, layers, variant).

    Each ``freeze_mask_{i}.npy`` file in ``folder_path`` must hold a 2-D array
    of shape (num_iterations, 2 * n_qubits * n_layers). Summing it over the
    iteration axis gives one intensity vector, which is reshaped to
    (n_layers, 2 * n_qubits).

    Every sample shares the same float32 input of shape
    (3, n_layers, 2 * n_qubits):
    - channel 0: layer coordinate, evenly spaced over [0, 1];
    - channel 1: column coordinate, evenly spaced over [0, 1];
    - channel 2: constant normalized variant id.

    Args:
        folder_path: Directory containing the ``freeze_mask_{i}.npy`` files.
        n_qubits: Number of qubits; map width is ``2 * n_qubits``.
        n_layers: Number of circuit layers; map height.
        variant_id: Variant index encoded in the third input channel.
        num_variants: Normalizer for ``variant_id``; the channel holds
            ``variant_id / (num_variants - 1)``, or 0.0 when ``num_variants <= 1``.
        num_runs: File indices ``0 .. num_runs - 1`` are looked up (default 20,
            the previously hard-coded count). Missing files are skipped with a
            warning.

    Raises:
        ValueError: If a file is not 2-D, or its second axis does not equal
            ``2 * n_qubits * n_layers``.
    """

    def __init__(self, folder_path, n_qubits, n_layers, variant_id, num_variants=20,
                 num_runs=20):
        self.samples = []
        self.n_qubits = n_qubits
        self.n_layers = n_layers
        self.num_variants = num_variants
        self.num_runs = num_runs
        self.expected_param_count = 2 * n_qubits * n_layers

        # Normalized (layer, column) coordinate planes, each (n_layers, 2 * n_qubits).
        layer_grid = np.linspace(0, 1, n_layers).reshape(n_layers, 1)
        col_grid = np.linspace(0, 1, 2 * n_qubits).reshape(1, 2 * n_qubits)
        coord_layer = np.repeat(layer_grid, 2 * n_qubits, axis=1)
        coord_col = np.repeat(col_grid, n_layers, axis=0)
        self.coord_input = np.stack([coord_layer, coord_col], axis=0).astype(np.float32)

        if num_variants <= 1:
            variant_norm = 0.0
        else:
            variant_norm = variant_id / (num_variants - 1)
        variant_channel = np.full((n_layers, 2 * n_qubits), variant_norm, dtype=np.float32)

        # Shared 3-channel input; each sample gets its own copy below.
        self.input_template = np.concatenate([self.coord_input, variant_channel[np.newaxis, ...]], axis=0)

        for i in range(num_runs):
            freeze_file = os.path.join(folder_path, f"freeze_mask_{i}.npy")
            if not os.path.exists(freeze_file):
                print(f"Warning: File {freeze_file} not found, skipping.")
                continue
            freeze_array = np.load(freeze_file)
            if freeze_array.ndim != 2:
                raise ValueError(
                    f"In file {freeze_file}, expected a 2-D (iterations, params) array, "
                    f"got shape {freeze_array.shape}"
                )
            num_params = freeze_array.shape[1]
            if num_params != self.expected_param_count:
                raise ValueError(f"In file {freeze_file}, number of params {num_params} does not equal expected {self.expected_param_count}")

            # Count how often each parameter was frozen across iterations.
            intensity_vector = np.sum(freeze_array, axis=0).astype(np.float32)
            intensity_map = intensity_vector.reshape(n_layers, 2 * n_qubits)
            self.samples.append((self.input_template.copy(), intensity_map))

        print(f"[CNNIntensityDataset] Found {len(self.samples)} valid samples in folder {folder_path}.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(input, target)``: (3, L, W) input and (1, L, W) target tensors."""
        input_data, intensity_map = self.samples[idx]
        coord_tensor = torch.tensor(input_data)
        intensity_tensor = torch.tensor(intensity_map).unsqueeze(0)
        return coord_tensor, intensity_tensor


class AggregatedIntensityDataset(Dataset):
    """Map-style dataset over a pre-built list of ``(input_array, intensity_map)`` pairs.

    Items are converted to new tensors on access. The target gets a leading
    channel axis, so a (L, W) map is returned as (1, L, W).
    """

    def __init__(self, samples):
        # Stored by reference; the caller's list is not copied.
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        features, target = self.samples[idx]
        return torch.tensor(features), torch.tensor(target).unsqueeze(0)



def train_intensity_predictor_cnn(dataset, epochs=100, lr=1e-3, device="cpu",
                                  loss_plot_path="aggregated_training_loss.png", model=None,
                                  batch_size=1):
    """Train an intensity predictor with Adam + MSE and save the loss curve.

    Args:
        dataset: Map-style dataset yielding ``(input, target)`` tensor pairs.
        epochs: Number of passes over ``dataset``.
        lr: Adam learning rate.
        device: Device the model and each batch are moved to.
        loss_plot_path: File the per-epoch loss plot is written to.
        model: Model to train. When None, a new
            ``ResNet18SelfAttentionPredictor(in_channels=3, out_channels=1, num_heads=4)``
            is built. A caller-supplied model is also moved to ``device``.
        batch_size: Samples per optimization step (default 1). Values > 1 use
            the DataLoader's default collation, which stacks tensors, so all
            samples in a batch must then share the same spatial size.

    Returns:
        The trained model, left in training mode.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("Cannot train on an empty dataset.")

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    if model is None:
        model = ResNet18SelfAttentionPredictor(
            in_channels=3,
            out_channels=1,
            num_heads=4
        )
    # Move unconditionally: a caller-supplied model may still be on another device.
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    loss_history = []
    model.train()

    for ep in range(epochs):
        total_loss = 0.0
        for batch_input, batch_label in dataloader:
            batch_input = batch_input.to(device)
            batch_label = batch_label.to(device).float()
            optimizer.zero_grad()
            pred = model(batch_input)
            loss = criterion(pred, batch_label)
            loss.backward()
            optimizer.step()
            # MSELoss averages within the batch; weight by batch size so the
            # epoch value is a per-sample average for any batch_size.
            total_loss += loss.item() * batch_input.size(0)
        avg_loss = total_loss / len(dataset)
        loss_history.append(avg_loss)
        if (ep + 1) % 10 == 0:
            print(f"Epoch {ep + 1}: MSE loss = {avg_loss:.6f}")

    fig, ax = plt.subplots()
    try:
        ax.plot(range(1, epochs + 1), loss_history, marker="o")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Average MSE Loss")
        ax.set_title("Aggregated Training Loss Curve")
        ax.grid(True)
        fig.savefig(loss_plot_path)
        print(f"Training loss plot saved to {loss_plot_path}")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    return model


def construct_input_tensor(n_qubits, n_layers, variant, num_variants=20, device="cpu"):
    """Build a single-sample model input for one circuit design.

    Returns a float32 tensor of shape (1, 3, n_layers, 2 * n_qubits) on
    ``device``:
    - channel 0: layer coordinate, evenly spaced over [0, 1] down the rows;
    - channel 1: column coordinate, evenly spaced over [0, 1] across the columns;
    - channel 2: the constant ``variant / (num_variants - 1)``, or 0.0 when
      ``num_variants <= 1``.
    """
    width = 2 * n_qubits
    # 'ij' indexing gives rows[i, j] = layer_coord[i] and cols[i, j] = col_coord[j].
    rows, cols = np.meshgrid(
        np.linspace(0, 1, n_layers),
        np.linspace(0, 1, width),
        indexing="ij",
    )

    if num_variants > 1:
        variant_norm = variant / (num_variants - 1)
    else:
        variant_norm = 0.0

    planes = [
        rows.astype(np.float32),
        cols.astype(np.float32),
        np.full((n_layers, width), variant_norm, dtype=np.float32),
    ]
    batch = np.stack(planes, axis=0)[np.newaxis]
    return torch.tensor(batch).to(device)



def predict_intensity(model, n_qubits, n_layers, variant, device="cpu", num_variants=20):
    """Predict the freeze-intensity map for one circuit design.

    Puts ``model`` into eval mode (it is left there), builds the coordinate /
    variant input with ``construct_input_tensor`` and runs one forward pass
    without gradient tracking.

    Args:
        model: Trained predictor; presumably already on ``device`` (not moved here).
        n_qubits: Number of qubits; the input width is ``2 * n_qubits``.
        n_layers: Number of circuit layers; the input height.
        variant: Variant index to encode in the third input channel.
        device: Device the input tensor is created on.
        num_variants: Normalizer for ``variant``. It must equal the
            ``num_variants`` the training inputs were built with, otherwise
            the variant channel is scaled differently from training. Defaults
            to 20, the value previously hard-coded here.

    Returns:
        The prediction as a NumPy array with the batch and channel axes
        squeezed. For a single-channel model that preserves the input size,
        this is (n_layers, 2 * n_qubits).
    """
    model.eval()
    input_tensor = construct_input_tensor(n_qubits, n_layers, variant,
                                          num_variants=num_variants, device=device)
    with torch.no_grad():
        pred = model(input_tensor)
    return pred.squeeze(0).squeeze(0).cpu().numpy()


def draw_intensity_comparison(aggregated_pred, aggregated_gt, n_qubits, n_layers,
                              title_pred="Predicted Freeze Intensity",
                              title_gt="Ground Truth Freeze Intensity",
                              save_name=None, vmin=0, vmax=100, show_cz=True):
    """Draw predicted and ground-truth intensities side by side as circuit diagrams.

    Both inputs are flattened to length ``2 * n_qubits * n_layers`` and read
    layer by layer. Within layer ``l``:
    - the first ``n_qubits`` entries are RX markers at x = l + 0.25;
    - the next ``n_qubits`` entries are RY markers at x = l + 0.75;
    - each marker sits on its qubit row.

    Marker colour is the "Blues" colormap applied to values normalized to
    [vmin, vmax]; out-of-range values get the colormap's under/over colours.

    Args:
        aggregated_pred: Predicted intensities; any array-like with
            ``2 * n_qubits * n_layers`` elements.
        aggregated_gt: Ground-truth intensities with the same number of elements.
        n_qubits: Number of qubit rows.
        n_layers: Number of layers along the x axis.
        title_pred: Title of the left (prediction) panel.
        title_gt: Title of the right (ground-truth) panel.
        save_name: Output path. When None the figure is shown interactively.
        vmin: Lower colour-scale limit.
        vmax: Upper colour-scale limit.
        show_cz: Whether to draw the vertical connector lines between qubit
            ``q`` and ``(q + 1) % n_qubits`` at every layer.

    Raises:
        ValueError: If either input does not have exactly
            ``2 * n_qubits * n_layers`` elements.
    """
    pred_flat = np.asarray(aggregated_pred).reshape(-1)
    gt_flat = np.asarray(aggregated_gt).reshape(-1)
    expected_len = 2 * n_qubits * n_layers
    if pred_flat.shape[0] != expected_len or gt_flat.shape[0] != expected_len:
        raise ValueError("Intensity array length does not match the expected value.")
    cmap_pred = plt.get_cmap("Blues")
    cmap_gt = plt.get_cmap("Blues")
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

    def draw_on_ax(ax, count_1d, title, cmap_obj):
        """Render one circuit panel on ``ax`` from a flat intensity vector."""
        ax.set_title(title, fontsize=18)
        # Horizontal qubit wires.
        for q in range(n_qubits):
            ax.plot([0, n_layers], [q, q], color="black", linewidth=1, alpha=0.5)
        if show_cz:
            # Wraps from the last qubit back to qubit 0 via the modulo.
            for layer_idx in range(n_layers):
                for qubit_idx in range(n_qubits):
                    y1 = qubit_idx
                    y2 = (qubit_idx + 1) % n_qubits
                    ax.plot([layer_idx, layer_idx], [y1, y2], color="gray", linewidth=1, alpha=0.3)
        idx_global = 0
        for layer_idx in range(n_layers):
            rx_indices = np.arange(idx_global, idx_global + n_qubits)
            rx_vals = count_1d[rx_indices]
            x_rx = np.full(n_qubits, layer_idx + 0.25)
            y_rx = np.arange(n_qubits)
            ry_indices = np.arange(idx_global + n_qubits, idx_global + 2 * n_qubits)
            ry_vals = count_1d[ry_indices]
            x_ry = np.full(n_qubits, layer_idx + 0.75)
            y_ry = np.arange(n_qubits)
            idx_global += 2 * n_qubits
            colors_rx = [cmap_obj(norm(val)) for val in rx_vals]
            colors_ry = [cmap_obj(norm(val)) for val in ry_vals]
            ax.scatter(x_rx, y_rx, c=colors_rx, s=320, marker="o", edgecolors="k", linewidths=0.5)
            ax.scatter(x_ry, y_ry, c=colors_ry, s=320, marker="o", edgecolors="k", linewidths=0.5)
        ax.set_xlim(-0.5, n_layers + 0.5)
        ax.set_ylim(-1, n_qubits)
        ax.set_xlabel("Layer index", fontsize=16)
        ax.set_ylabel("Qubit index", fontsize=16)
        ax.set_xticks(range(n_layers))
        ax.set_yticks(range(n_qubits))
        ax.grid(False)
        sm = mpl.cm.ScalarMappable(norm=norm, cmap=cmap_obj)
        sm.set_array([])
        cbar = plt.colorbar(sm, ax=ax)
        cbar.set_label("Counts", fontsize=14)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))
    try:
        fig.suptitle("Intensity Comparison: Predicted vs Ground Truth", fontsize=20)
        draw_on_ax(ax1, pred_flat, title_pred, cmap_pred)
        draw_on_ax(ax2, gt_flat, title_gt, cmap_gt)
        plt.tight_layout(rect=[0, 0, 1, 0.95])
        if save_name is not None:
            plt.savefig(save_name)
            print(f"Intensity comparison figure saved to {save_name}")
        else:
            plt.show()
    finally:
        # Close on every path, including when drawing or saving raises.
        plt.close(fig)

def main():
    """Train the freeze-intensity predictor on all available designs and visualize one sample.

    Steps:
    1. Scan the ``./Dataset/...`` folders for every (qubits, layers, variant)
       combination and pool their samples.
    2. Train a fresh predictor for 1200 epochs.
    3. Save the weights and the loss curve under ``Saved_model/``.
    4. Predict a randomly chosen sample from the training pool, show a preview,
       and save a predicted-vs-ground-truth comparison figure.
    """
    noise_rate_target = 0
    qubits_range = range(5, 16)
    layers_range = range(5, 16)
    variant_range = [0]
    aggregated_samples = []
    # Variant id of each entry in aggregated_samples (index-aligned), so the
    # visualization step reports the sample's real variant instead of assuming 0.
    sample_variants = []

    for n_qubits in qubits_range:
        for n_layers in layers_range:
            n_params = 2 * n_qubits * n_layers
            for variant in variant_range:
                circuit_name_target = f"Dataset/tgtGD_variant{variant}_qubits{n_qubits}_layers{n_layers}"
                folder_tgt = f"./{circuit_name_target}_GD_noNoise_{n_qubits}_{n_params}_gd_apfa_tgt_{noise_rate_target}"
                if not os.path.exists(folder_tgt):
                    print(f"Folder {folder_tgt} does not exist. Skipping.")
                    continue
                ds = CNNIntensityDataset(
                    folder_path=folder_tgt,
                    n_qubits=n_qubits,
                    n_layers=n_layers,
                    variant_id=variant,
                    num_variants=len(variant_range)
                )
                aggregated_samples.extend(ds.samples)
                sample_variants.extend([variant] * len(ds.samples))

    if not aggregated_samples:
        print("No samples found for training. Please check data paths.")
        return

    print(f"Total aggregated samples: {len(aggregated_samples)}")
    aggregated_dataset = AggregatedIntensityDataset(aggregated_samples)

    result_folder = 'Saved_model'
    os.makedirs(result_folder, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_plot_path = os.path.join(result_folder, "aggregated_training_loss.png")

    model = train_intensity_predictor_cnn(
        aggregated_dataset,
        epochs=1200,
        lr=1e-3,
        device=device,
        loss_plot_path=loss_plot_path,
        model=None
    )

    model_save_path = os.path.join(result_folder, "trained_resnet_attention.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Trained model saved to {model_save_path}")

    # NOTE: the sample comes from the training pool, so this is a fit check,
    # not a held-out evaluation.
    test_idx = random.randrange(len(aggregated_samples))
    test_input, test_intensity = aggregated_samples[test_idx]
    test_variant = sample_variants[test_idx]
    _, test_n_layers, test_width = test_input.shape
    test_n_qubits = test_width // 2
    print(f"\nTesting on a randomly selected design: qubits={test_n_qubits}, layers={test_n_layers}, variant={test_variant}")

    # Run the model on the sample's own input, not on one rebuilt via
    # predict_intensity(). That function normalizes the variant with
    # num_variants=20, which need not match len(variant_range) used in training.
    model.eval()
    with torch.no_grad():
        pred = model(torch.tensor(test_input).unsqueeze(0).to(device))
    predicted_intensity = pred.squeeze(0).squeeze(0).cpu().numpy()

    preview_fig = plt.figure(figsize=(8, 6))
    plt.imshow(predicted_intensity, cmap="Blues", aspect="auto")
    plt.title(f"Predicted Intensity (variant={test_variant}, qubits={test_n_qubits}, layers={test_n_layers})")
    plt.colorbar(label="Intensity")
    plt.show()
    # Release the preview figure so it does not stay open after show() returns.
    plt.close(preview_fig)

    comparison_fig_file = os.path.join(result_folder, "pred_vs_gt_intensity_sample.png")
    draw_intensity_comparison(
        aggregated_pred=predicted_intensity,
        aggregated_gt=test_intensity,
        n_qubits=test_n_qubits,
        n_layers=test_n_layers,
        title_pred="Predicted Freeze Intensity",
        title_gt="Ground Truth Freeze Intensity",
        save_name=comparison_fig_file,
        vmin=0,
        vmax=100,
        show_cz=True
    )
    print(f"Test intensity comparison saved to {comparison_fig_file}")


if __name__ == "__main__":
    # Run the full aggregation / training / visualization pipeline only when
    # this file is executed as a script, not when it is imported.
    main()
